import torch
import torch.distributed as dist
def loss_and_top1_acc(loss_per_gpu, top1_acc_per_gpu, local_rank):
    """Return the mean loss and mean top-1 accuracy, reduced across processes.

    Args:
        loss_per_gpu: Meter-like object for this process's loss; its ``sum``,
            ``count`` and ``avg`` attributes are read (presumably an
            ``AverageMeter`` -- confirm against the caller).
        top1_acc_per_gpu: Meter-like object with the same attributes for
            this process's top-1 accuracy.
        local_rank: Device the reduction tensor is placed on, or ``None``
            when not running distributed. When ``None``, the meters' own
            ``avg`` values are returned and no collective is issued.

    Returns:
        ``(loss, top1_acc)``. In the distributed case each value is the
        global sum divided by the global count of its own meter, or
        ``nan`` if that global count is zero.

    Note:
        In the distributed case every process in the default group must call
        this function, because it issues a collective ``all_reduce``.
    """
    if local_rank is None:
        return loss_per_gpu.avg, top1_acc_per_gpu.avg

    # Pack everything into one float64 tensor: a single collective instead of
    # three, and float64 avoids the float32 rounding that
    # torch.tensor(python_float) would apply to large running sums.
    # float() also accepts 0-d tensors, should a meter store its sum as one.
    totals = torch.tensor(
        [
            float(loss_per_gpu.sum),
            float(loss_per_gpu.count),
            float(top1_acc_per_gpu.sum),
            float(top1_acc_per_gpu.count),
        ],
        dtype=torch.float64,
        device=local_rank,
    )
    dist.all_reduce(totals, op=dist.ReduceOp.SUM)
    loss_sum, loss_count, acc_sum, acc_count = totals.tolist()

    # Each metric is divided by its OWN global sample count, so the result
    # agrees with the meter's ``avg`` used in the non-distributed branch.
    # An empty global total yields nan rather than a misleading 0.0 or a crash.
    loss = loss_sum / loss_count if loss_count else float("nan")
    top1_acc = acc_sum / acc_count if acc_count else float("nan")
    return loss, top1_acc